/*
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License version 2
 *   as published by the Free Software Foundation.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *   Copyright (C) 2008  Benjamin Segovia <segovia.benjamin@gmail.com>
 */

#ifdef _WIN32
    #define WINDOWS_LEAN_AND_MEAN
    #define NOMINMAX
    #include <windows.h>
#pragma warning(disable:4996)
#endif

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <assert.h>

#include "specifics.h"
#include "sys_clock.h"
#include "sys_log.h"
#include "sys_map.h"
#include "bvhlib.h"
#include "kdlib.h"
#include "kdlib_blockifier.h"
#include "rt_kdtree.h"
#include "rt_camera.h"
#include "cuda.h"
#include "cuda_runtime_api.h"
#include "cutil.h"
#include "sys_mem.h"

#define GLEW_STATIC
#include "glew.h"

#if defined(__APPLE__) || defined(MACOSX)
    #include <GLUT/glut.h>
#else
    #include <GL/glut.h>
#endif

#include <cutil.h>
#include <cutil_gl_error.h>

/* Window size in pixels: updated by reshape() and applied as the viewport in
 * display_image() */
unsigned int window_width = 1024;
unsigned int window_height = 512;
/* Size in pixels of the ray-traced image (PBO and screen texture) */
unsigned int image_w = 1024;
unsigned int image_h = 512;

/* Pbo variables: the PBO the CUDA renderers write into, and its size
 * bookkeeping filled in by create_pbo() (4 bytes per texel) */
GLuint pbo_dest;
unsigned int size_tex_data;
unsigned int num_texels;
unsigned int num_values;

/* Render target: texture the PBO is copied into (process_image()) before
 * being drawn as a full-screen quad (display_image()) */
GLuint tex_screen;

/* The camera, refreshed from eye / fov every frame */
rt::camera_t cam;

/* The current display mode: selects cuda_shadowed / cuda_unshadowed /
 * cuda_noshading in process_image(); changed with keys '1', '2' and '3' */
enum display_mode_t {shadowed = 0, unshadowed, noshading };
uint32_t display_mode = unshadowed;

/* The light position, animated every frame by process_image() */
float lpos[3];

/* The per-triangle unit normals, computed in scene_compile() and uploaded to
 * the device in run() */
std::vector<vec_t> normals;

/* The bounding box of the scene (passed to kdlib_compile() in scene_compile()
 * and to the CUDA renderers) */
aabb_t aabb;

/* The eye position (moved with w/s, a/d, r/f) and the horizontal field of
 * view given to cam.set_fovx() -- presumably in degrees, TODO confirm */
vec_t eye(0.f, 0.5f, 2.4f);
float fov = 90.f;

/* The kd-tree we want to intersect on the GPU (blockified host copy) */
kdtree::descriptor_t app_kd_tree;

/* The device pointers of the normals, accelerated triangles, kd-tree nodes
 * and triangle ids: allocated in run(), freed on <Esc> in keyboard() */
void *device_normal;
void *device_rt_tris;
void *device_kd_nodes;
void *device_kd_ids;

/* When false, display() skips the CUDA rendering and the blit (toggled with
 * the space bar) */
bool enable_cuda = true;

/* Entry points presumably provided by the CUDA translation unit: device
 * initialization, the three renderers (they receive the PBO handle, the
 * camera, the scene box, the light position and the image size), the CUDA
 * registration of the PBO and the binding of the scene arrays to textures */
extern "C" void init_cuda(int argc, char **argv);
extern "C" void cuda_shadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
extern "C" void cuda_unshadowed(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
extern "C" void cuda_noshading(int pbo_out, void *cam, void *aabb, float *pos, int w, int h);
extern "C" void pbo_register(int pbo);
extern "C" void pbo_unregister(int pbo);
extern "C" void node_texture_bind(const void *data, size_t size);
extern "C" void rt_tri_texture_bind(const void *data, size_t size);
extern "C" void tri_texture_bind(const void *data, size_t size);
extern "C" void id_texture_bind(const void *data, size_t size);
extern "C" void normal_texture_bind(const void *data, size_t size);

/* Host-side functions defined later in this file; forward-declared (once
 * each) because they are referenced before their definition */
extern void run(int argc, char** argv);
extern CUTBoolean init_gl();
extern void create_pbo(GLuint* pbo);
extern void delete_pbo(GLuint* pbo);
extern void create_texture(GLuint* tex_name, unsigned int size_x, unsigned int size_y);
extern void delete_texture(GLuint* tex);
extern void display();
extern void idle();
extern void keyboard(unsigned char key, int x, int y);
extern void reshape(int w, int h);
extern void main_menu(int i);

/* Reporting helper for unrecoverable errors; it is not defined in this file */
extern void fatal(const char * const reason);

/***************************************************************************//**
 * Compute the accelerated (Wald-style projected) form of one triangle used by
 * the intersector: the triangle is projected onto the axis plane orthogonal to
 * the dominant axis k of its geometric normal, and the plane / edge equations
 * are pre-divided so that the intersection needs no division per edge.
 *
 * @param t       source triangle
 * @param w       receives the accelerated triangle
 * @param id      triangle index stored in w.id
 * @param mat_id  material index stored in w.matid
 * @return        true when the triangle is degenerated (zero normal component
 *                along k or zero projected area); w is filled in any case
/******************************************************************************/
static bool_t
compute_tri_acc(
    const triangle_t &t, rt::wald_tri_t &w,
    const uint_t id, const uint_t mat_id)
{
    /* Edges from the first vertex and the (unnormalized) geometric normal */
    const vec_t &origin = t.verts[0];
    const vec_t edge_b = t.verts[2] - origin;
    const vec_t edge_c = t.verts[1] - origin;
    const vec_t normal = edge_b.cross(edge_c);

    /* Dominant axis of the normal (the lowest index wins on ties) and the two
     * remaining axes spanning the projection plane */
    uint_t k = 0;
    for (uint_t axis = 1; axis < 3; ++axis)
        if (fabsf(normal[axis]) > fabsf(normal[k]))
            k = axis;
    const uint_t u = (k + 1) % 3;
    const uint_t v = (k + 2) % 3;

    /* Projected area term and normal component along k: either one being
     * zero marks the triangle as degenerated */
    const float denom = (edge_b[u] * edge_c[v] - edge_b[v] * edge_c[u]);
    const float krec = normal[k];

    /* Plane equation scaled by 1 / normal[k] */
    const float n_u = normal[u] / krec;
    const float n_v = normal[v] / krec;
    const float n_d = normal.dot(origin) / krec;

    /* Edge equations in the projection plane scaled by 1 / denom */
    const float b_nu =  edge_b[u] / denom;
    const float b_nv = -edge_b[v] / denom;
    const float c_nu =  edge_c[v] / denom;
    const float c_nv = -edge_c[u] / denom;

    w.k = k;
    w.id = id;
    w.matid = mat_id;
    w.n_u = n_u;
    w.n_v = n_v;
    w.n_d = n_d;
    w.vert_ku = float(origin[u]);
    w.vert_kv = float(origin[v]);
    w.b_nu = b_nu;
    w.b_nv = b_nv;
    w.c_nu = c_nu;
    w.c_nv = c_nv;

    return (krec == 0.) || (denom == 0.);
}

/***************************************************************************//**
 * Compile the accelerated triangles used for the intersections: acc[i] is
 * built from tri[i] with id i and material 0. The number of degenerated
 * triangles is logged.
/******************************************************************************/
static NOINLINE void
bake_intersection(
    const triangle_t * __restrict const tri,
    const uint_t tri_n, rt::wald_tri_t * __restrict acc)
{
    int degenerated_n = 0;
    uint_t tid = 0;
    while (tid < tri_n) {
        if (compute_tri_acc(tri[tid], acc[tid], tid, 0))
            ++degenerated_n;
        ++tid;
    }
    sys::log("bake_intersection: %d triangles, %d degenerated.\n", tri_n, degenerated_n);
}

/***************************************************************************//**
 * Load the triangle soup stored in `file`, build the kd-tree over it,
 * re-lay it out in blocks (the unblocked version is released) and compute
 * the per-triangle unit normals into the global `normals`.
 *
 * @param kd_tree  receives the blockified kd-tree and its accelerated triangles
 * @param file     path of the raw triangle file (mapped in memory)
 * @return         the number of triangles of the scene
 *
 * Calls fatal() when the file cannot be mapped. The scene bounds are passed
 * to kdlib_compile() through the global `aabb` (presumably filled by it).
/******************************************************************************/
static NOINLINE int scene_compile(kdtree::descriptor_t &kd_tree, const char *file)
{
    sys::log("loading %s...\n", file);

    /* Map the triangle soup and build the kd-tree over it */
    sys::map_t scene_map;
    scene_map.open(file);
    if (!scene_map.is_mapped())
        fatal("failed to mmap scene data.");
    const triangle_t * __restrict const soup =
        (const triangle_t * __restrict const) scene_map.begin();
    const uint32_t tri_n = scene_map.get_size<triangle_t>();
    kdlib_compile(soup, tri_n, kd_tree, aabb);

    /* Precompute the accelerated triangles, one per soup triangle */
    rt::wald_tri_t *baked =
        (rt::wald_tri_t *) sys::mem::allocate(tri_n * sizeof(rt::wald_tri_t));
    kd_tree.acc = baked;
    bake_intersection(soup, tri_n, baked);

    /* Re-lay the kd-tree out in blocks and drop the unblocked arrays */
    kdtree::descriptor_t blocked;
    kdlib::do_blockify(blocked, kd_tree, 4);
    sys::mem::liberate((void *) kd_tree.acc);
    sys::mem::liberate((void *) kd_tree.ids);
    sys::mem::liberate((void *) kd_tree.root);
    kd_tree = blocked;

    /* One unit normal per triangle: (v1 - v0) x (v2 - v0), normalized */
    normals.resize(tri_n);
    for (uint32_t tri = 0; tri < tri_n; ++tri) {
        const vec_t &origin = soup[tri].verts[0];
        const vec_t e02 = soup[tri].verts[2] - origin;
        const vec_t e01 = soup[tri].verts[1] - origin;
        vec_t n = e01.cross(e02);
        normals[tri] = n.normalize();
    }
    return tri_n;
}

/***************************************************************************//**
 * Main Program: all the work is done by run(). Per the GLUT specification
 * glutMainLoop() never returns, so EXIT_SUCCESS is only reached when run()
 * returns before entering the main loop.
/******************************************************************************/
int main(int argc, char** argv)
{
    run(argc, argv);
    return EXIT_SUCCESS;
}

/***************************************************************************//**
 * Run a simple test in CUDA
/******************************************************************************/
void run(int argc, char** argv)
{
    /* Load the data file and compile the bvh tree */
    scene_compile(app_kd_tree, "Data/FairyForestF160.ra2");
    init_cuda(argc, argv);

    /* Create GL context */
    glutInit(&argc, argv);
    glutInitDisplayMode(GLUT_RGBA | GLUT_ALPHA | GLUT_DOUBLE | GLUT_DEPTH);
    glutInitWindowSize(window_width, window_height);
    glutCreateWindow("CUDA RTRT");

    /* Initialize GL */
    if(CUTFalse == init_gl()) return;

    /* Init the timers */
    sys::laps_t::bootstrap();

    /* Register callbacks */
    glutDisplayFunc(display);
    glutKeyboardFunc(keyboard);
    glutReshapeFunc(reshape);
    glutIdleFunc(idle);

    /* Create menu */
    glutCreateMenu(main_menu);
    glutAddMenuEntry("Toggle CUDA processing [ ]", ' ');
    glutAddMenuEntry("With shadows", '1');
    glutAddMenuEntry("Without shadows", '2');
    glutAddMenuEntry("Without shading", '3');
    glutAddMenuEntry("Quit (esc)", '\033');
    glutAttachMenu(GLUT_RIGHT_BUTTON);

    /* Create pbo */
    create_pbo(&pbo_dest);

    /* Create texture for blitting onto the screen */
    create_texture(&tex_screen, image_w, image_h);

    /* Set a camera here */
    cam.open();
    cam.set_eye(eye);
    cam.set_fovx(fov);
    cam.update(point_t(image_w, image_h));

    /* Allocate the device arrays */
    const uint32_t tri_size = (uint32_t) app_kd_tree.tri_n * sizeof(app_kd_tree.acc[0]);
    const uint32_t kd_id_size = app_kd_tree.id_n * sizeof(app_kd_tree.ids[0]);
    const uint32_t kd_node_size = app_kd_tree.node_n * sizeof(app_kd_tree.root[0]);
    const uint32_t normal_size = app_kd_tree.tri_n * sizeof(normals[0]);
    CUDA_SAFE_CALL(cudaMalloc(&device_rt_tris, tri_size));
    CUDA_SAFE_CALL(cudaMalloc(&device_kd_ids, kd_id_size));
    CUDA_SAFE_CALL(cudaMalloc(&device_kd_nodes, kd_node_size));
    CUDA_SAFE_CALL(cudaMalloc(&device_normal, normal_size));
    CUDA_SAFE_CALL(cudaMemcpy(device_kd_nodes, &app_kd_tree.root[0], kd_node_size, cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL(cudaMemcpy(device_kd_ids, &app_kd_tree.ids[0], kd_id_size, cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL(cudaMemcpy(device_rt_tris, &app_kd_tree.acc[0], tri_size, cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL(cudaMemcpy(device_normal, &normals[0], normal_size, cudaMemcpyHostToDevice));

    /* Bind all textures */
    node_texture_bind(device_kd_nodes, kd_node_size);
    rt_tri_texture_bind(device_rt_tris, tri_size);
    id_texture_bind(device_kd_ids, kd_id_size);
    normal_texture_bind(device_normal, normal_size);

    /* Start rendering main loop */
    glutMainLoop();
}

/***************************************************************************//**
 * Initialize GLEW and check that the OpenGL features required by the viewer
 * (GL 2.0, pixel buffer objects, framebuffer objects) are available.
 *
 * @return CUTTrue on success, CUTFalse (after printing the reason on stderr)
 *         when GLEW fails to initialize or an extension is missing
/******************************************************************************/
CUTBoolean init_gl()
{
    /* glewInit() reports its own failures; glewIsSupported() is meaningless
     * when it did not succeed */
    const GLenum glew_status = glewInit();
    if (glew_status != GLEW_OK) {
        fprintf(stderr, "ERROR: glewInit failed: %s\n",
                (const char *) glewGetErrorString(glew_status));
        fflush(stderr);
        return CUTFalse;
    }
    if (!glewIsSupported("GL_VERSION_2_0 "
                         "GL_ARB_pixel_buffer_object "
                         "GL_EXT_framebuffer_object ")) {
        fprintf(stderr, "ERROR: Support for necessary OpenGL extensions missing.\n");
        fflush(stderr);
        return CUTFalse;
    }
    return CUTTrue;
}

/***************************************************************************//**
 * Create the output PBO the CUDA renderers write into and register it with
 * CUDA. Also records its geometry in num_texels, num_values and
 * size_tex_data (4 bytes per texel of the image_w x image_h image).
 *
 * @param pbo  receives the GL name of the new buffer
/******************************************************************************/
void create_pbo(GLuint* pbo)
{
    num_texels = image_w * image_h;
    num_values = num_texels * 4;
    size_tex_data = sizeof(GLubyte) * num_values;

    glGenBuffers(1, pbo);
    glBindBuffer(GL_ARRAY_BUFFER, *pbo);
    /* NULL data only allocates the storage. There is nothing meaningful to
     * upload: the content is produced by the CUDA renderers, so staging an
     * (uninitialized) host buffer would only cost a malloc and a copy. */
    glBufferData(GL_ARRAY_BUFFER, size_tex_data, NULL, GL_DYNAMIC_DRAW);
    glBindBuffer(GL_ARRAY_BUFFER, 0);

    pbo_register(*pbo);
    CUT_CHECK_ERROR_GL();
}

/***************************************************************************//**
 * Delete the output PBO: unregister it from CUDA (create_pbo() registered
 * it), delete the GL buffer and reset the handle to 0. A handle that is
 * already 0 is left untouched, so calling this twice is harmless.
 *
 * @param pbo  GL name of the buffer; set to 0 on return
/******************************************************************************/
void delete_pbo(GLuint* pbo)
{
    if (*pbo == 0)
        return;

    /* The buffer must be released by CUDA before GL frees its storage */
    pbo_unregister(*pbo);

    /* glDeleteBuffers also unbinds the buffer if it is currently bound */
    glDeleteBuffers(1, pbo);
    CUT_CHECK_ERROR_GL();
    *pbo = 0;
}

/***************************************************************************//**
 * Render the current frame with the CUDA renderer selected by display_mode
 * into the PBO, then copy the PBO into the screen texture. The PBO is left
 * bound to GL_PIXEL_UNPACK_BUFFER_ARB (display_image() unbinds it).
/******************************************************************************/
void process_image()
{
    /* Refresh the camera from the current eye position and field of view */
    cam.open();
    cam.set_eye(eye);
    cam.set_fovx(fov);
    cam.update(point_t(image_w, image_h));

    /* Move the light on a circle of radius .5 centered on (1, 2.5, 0); the
     * angle is the current time (presumably ms scaled to seconds) */
    const float t = (float) sys::laps_t::to_time(sys::laps_t::get()) * 0.001f;
    lpos[0] = 1.f + .5f * sinf(t);
    lpos[1] = 2.5f;
    lpos[2] = .5f * cosf(t);

    if (display_mode == shadowed)
        cuda_shadowed(pbo_dest, &cam, &aabb, lpos, image_w, image_h);
    else if (display_mode == unshadowed)
        cuda_unshadowed(pbo_dest, &cam, &aabb, lpos, image_w, image_h);
    else if (display_mode == noshading)
        cuda_noshading(pbo_dest, &cam, &aabb, lpos, image_w, image_h);

    /* PBO -> texture copy: with an unpack buffer bound, the NULL pointer is
     * an offset into the PBO */
    glBindBuffer(GL_PIXEL_UNPACK_BUFFER_ARB, pbo_dest);
    glBindTexture(GL_TEXTURE_2D, tex_screen);
    glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0,
                    image_w, image_h,
                    GL_BGRA, GL_UNSIGNED_BYTE, NULL);

    CUT_CHECK_ERROR_GL();
}

/***************************************************************************//**
 * Draw the screen texture as a quad covering the whole viewport (first
 * texture row at the top), then restore the projection matrix and unbind the
 * pixel pack / unpack buffers.
/******************************************************************************/
void display_image()
{
    /* Quad corners in drawing order: texture (s, t) then position (x, y) */
    static const float quad[4][4] = {
        { 0.f, 1.f, -1.f, -1.f },
        { 1.f, 1.f,  1.f, -1.f },
        { 1.f, 0.f,  1.f,  1.f },
        { 0.f, 0.f, -1.f,  1.f },
    };

    /* Plain textured drawing: no depth test, no lighting */
    glDisable(GL_DEPTH_TEST);
    glDisable(GL_LIGHTING);
    glEnable(GL_TEXTURE_2D);
    glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_REPLACE);

    /* Identity-like [-1, 1] projection, saved so it can be restored below */
    glMatrixMode(GL_PROJECTION);
    glPushMatrix();
    glLoadIdentity();
    glOrtho(-1.0, 1.0, -1.0, 1.0, -1.0, 1.0);
    glMatrixMode(GL_MODELVIEW);
    glLoadIdentity();
    glViewport(0, 0, window_width, window_height);

    glBegin(GL_QUADS);
    for (int corner = 0; corner < 4; ++corner) {
        glTexCoord2f(quad[corner][0], quad[corner][1]);
        glVertex3f(quad[corner][2], quad[corner][3], 0.5f);
    }
    glEnd();

    glMatrixMode(GL_PROJECTION);
    glPopMatrix();
    glDisable(GL_TEXTURE_2D);
    glBindBuffer(GL_PIXEL_PACK_BUFFER_ARB, 0);
    glBindBuffer(GL_PIXEL_UNPACK_BUFFER_ARB, 0);
    CUT_CHECK_ERROR_GL();
}

/***************************************************************************//**
 * Display callback
/******************************************************************************/
void display()
{
    sys::laps_t laps;
    if (enable_cuda) {
        process_image();
        display_image();
    }
    glutSwapBuffers();
    const uint64_t elapsed = laps.elapsed();
    float dt((float) sys::laps_t::to_time(elapsed));
    char msg[256];
    sprintf(msg, "CUDA ray tracer -- %2.2f MRay / s", image_w * image_h / 1000.f / dt);
    glutSetWindowTitle(msg);
}

/***************************************************************************//**
 * Idle callback: request a redraw so that display() (registered with
 * glutDisplayFunc() in run()) is called continuously
/******************************************************************************/
void idle()
{
    glutPostRedisplay();
}

/***************************************************************************//**
 * Release the device arrays, the GL objects and the host kd-tree, then
 * terminate the process (bound to the <Esc> key)
/******************************************************************************/
static void release_and_quit()
{
    /* Device copies of the scene */
    CUDA_SAFE_CALL(cudaFree(device_rt_tris));
    CUDA_SAFE_CALL(cudaFree(device_kd_ids));
    CUDA_SAFE_CALL(cudaFree(device_kd_nodes));
    CUDA_SAFE_CALL(cudaFree(device_normal));

    /* GL objects */
    delete_pbo(&pbo_dest);
    delete_texture(&tex_screen);

    /* Host copy of the blockified kd-tree */
    sys::mem::liberate((void *) app_kd_tree.acc);
    sys::mem::liberate((void *) app_kd_tree.ids);
    sys::mem::liberate((void *) app_kd_tree.root);
    exit(0);
}

/***************************************************************************//**
 * Keyboard handler: <Esc> quits, <Space> toggles the CUDA rendering, '1' to
 * '3' select the display mode, and w/s, a/d, r/f move the eye by 0.1 along
 * z, x and y respectively. Other keys are ignored.
/******************************************************************************/
void keyboard(unsigned char key, int /*x*/, int /*y*/)
{
    switch (key) {
        case 27:  release_and_quit(); break;
        case ' ': enable_cuda = !enable_cuda; break;

        case '1': display_mode = shadowed; break;
        case '2': display_mode = unshadowed; break;
        case '3': display_mode = noshading; break;

        case 'w': eye.z -= 0.1f; break;
        case 's': eye.z += 0.1f; break;
        case 'a': eye.x -= 0.1f; break;
        case 'd': eye.x += 0.1f; break;
        case 'r': eye.y += 0.1f; break;
        case 'f': eye.y -= 0.1f; break;

        default: break;
    }
}

/***************************************************************************//**
 * Reshape callback: only record the new window size; display_image() applies
 * it through glViewport() on the next frame
/******************************************************************************/
void reshape(int w, int h)
{
    window_width = (unsigned int) w;
    window_height = (unsigned int) h;
}

/***************************************************************************//**
 * Main menu callback: every entry's value is the key code it emulates (see
 * the glutAddMenuEntry() calls in run()), so forward it to keyboard()
/******************************************************************************/
void main_menu(int i)
{
    const unsigned char key = (unsigned char) i;
    keyboard(key, 0, 0);
}

/***************************************************************************//**
 * Delete the texture named *tex, check for GL errors and reset the handle to 0
/******************************************************************************/
void delete_texture(GLuint* tex)
{
    glDeleteTextures(1, tex);
    CUT_CHECK_ERROR_GL();
    *tex = 0;
}

/***************************************************************************//**
 * Create a size_x x size_y RGBA texture with nearest filtering and
 * clamp-to-edge wrapping. The storage is allocated but not initialized; it is
 * left bound to GL_TEXTURE_2D.
 *
 * @param tex_name  receives the GL name of the new texture
/******************************************************************************/
void create_texture(GLuint* tex_name, unsigned int size_x, unsigned int size_y)
{
    static const GLenum params[4][2] = {
        { GL_TEXTURE_WRAP_S,     GL_CLAMP_TO_EDGE },
        { GL_TEXTURE_WRAP_T,     GL_CLAMP_TO_EDGE },
        { GL_TEXTURE_MIN_FILTER, GL_NEAREST },
        { GL_TEXTURE_MAG_FILTER, GL_NEAREST },
    };

    glGenTextures(1, tex_name);
    glBindTexture(GL_TEXTURE_2D, *tex_name);
    for (int p = 0; p < 4; ++p)
        glTexParameteri(GL_TEXTURE_2D, params[p][0], (GLint) params[p][1]);

    /* NULL: allocate only, the texels are filled later from the PBO */
    glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, size_x, size_y, 0, GL_RGBA,
                 GL_UNSIGNED_BYTE, NULL);
    CUT_CHECK_ERROR_GL();
}
